{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b9a79366-17cb-48ef-8b67-1e6acda8f3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import MllamaForConditionalGeneration, AutoProcessor\n",
    "from transformers import BitsAndBytesConfig\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f30711a6-ac24-48a8-bbf9-bea54ea3d209",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face Hub repo ID of the base checkpoint, kept here for reference.\n",
    "#model_id = 'meta-llama/Llama-3.2-11B-Vision-Instruct'\n",
    "# Active setting: no org prefix, so presumably a local directory containing the\n",
    "# checkpoint (relative to the working directory) -- TODO confirm.\n",
    "model_id = 'Llama-3.2-11B-Vision-Instruct'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f0993e5-e231-4dd1-b0e7-f6b1c06984b5",
   "metadata": {},
   "source": [
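    "Run only one of the following two cells: the first configures 4-bit `nf4` quantization, the second 8-bit LLM.int8() quantization. Both define `dtype`, `q_config` and `dest`, so whichever runs last wins."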
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d77e2716-c8ba-438d-8283-90f9c36c4687",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4-bit NF4 quantization (bitsandbytes)\n",
    "dtype = torch.bfloat16\n",
    "dest = 'Llama-3.2-11B-Vision-Instruct-bnb-nf4'\n",
    "\n",
    "# Reuse `dtype` for the compute dtype so it always matches the dtype\n",
    "# the model is loaded with.\n",
    "q_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type='nf4',\n",
    "    bnb_4bit_use_double_quant=False,\n",
    "    bnb_4bit_compute_dtype=dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6acebbe-82d9-4274-b02f-0c97bb355991",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM.int8() quantization (bitsandbytes)\n",
    "dtype = torch.float16\n",
    "dest = 'Llama-3.2-11B-Vision-Instruct-bnb-int8'\n",
    "\n",
    "# Modules left unquantized.\n",
    "skip_modules = [\n",
    "    # Embedding & lm_head layers: skipping these is more of a\n",
    "    # \"feel good\" measure than a hard requirement.\n",
    "    'vision_model.patch_embedding',\n",
    "    'vision_model.gated_positional_embedding',\n",
    "    'vision_model.gated_positional_embedding.tile_embedding',\n",
    "    'vision_model.pre_tile_positional_embedding',\n",
    "    'vision_model.pre_tile_positional_embedding.embedding',\n",
    "    'vision_model.post_tile_positional_embedding',\n",
    "    'vision_model.post_tile_positional_embedding.embedding',\n",
    "    'language_model.model.embed_tokens',\n",
    "    'language_model.lm_head',\n",
    "    # Quantizing the projector leads to CUDA assertion errors during\n",
    "    # inference, so it has to be skipped.\n",
    "    'multi_modal_projector',\n",
    "]\n",
    "\n",
    "q_config = BitsAndBytesConfig(\n",
    "    load_in_8bit=True,\n",
    "    llm_int8_skip_modules=skip_modules,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6500b2e1-3d15-4de9-b332-d87a756f47e1",
   "metadata": {},
   "source": [
    "Unfortunately, you have to be able to load the entire quantized model into VRAM.\n",
    "\n",
    "`nf4` will barely fit into 8 GB (but you won't be able to run inference with only 8 GB!)\n",
    "\n",
    "There's a way around this involving a custom `device_map` and setting `llm_int8_enable_fp32_cpu_offload=True` (yes, even on the `nf4` quant), but that would make the quantized model only loadable with that specific `device_map` and config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48bfa6d2-85cd-446b-966c-007f0cf3b33d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "166d4bb853ff42c0ba7a767eb85f515f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "load_kwargs = dict(\n",
    "    quantization_config=q_config,\n",
    "    device_map='auto',\n",
    "    # NOTE: torch_dtype doesn't set the dtype of the submodules\n",
    "    # (vision_model, language_model). Open question: does it matter?\n",
    "    torch_dtype=dtype,\n",
    ")\n",
    "model = MllamaForConditionalGeneration.from_pretrained(model_id, **load_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2f1ae14-5b25-4859-a41f-a3c19cf3805c",
   "metadata": {},
   "source": [
    "The following should save:\n",
    "\n",
    "* `config.json`\n",
    "* `generation_config.json`\n",
    "\n",
    "as well as the safetensors weight file(s) and, if sharded, the `model.safetensors.index.json` index file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6b3e0dfe-8168-43d3-919f-6d223163004d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the loaded (quantized) model's config and weights into `dest`.\n",
    "model.save_pretrained(dest)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd038407-ce74-4518-a2e6-e5ef3d815f45",
   "metadata": {},
   "source": [
    "And the following should save:\n",
    "\n",
    "* `chat_template.json`\n",
    "* `preprocessor_config.json`\n",
    "* `special_tokens_map.json`\n",
    "* `tokenizer.json`\n",
    "* `tokenizer_config.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a06f6856-9210-4337-a31a-b63f4c9bc0bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the processor (tokenizer, image preprocessor, chat template) alongside\n",
    "# the quantized weights so `dest` is a self-contained checkpoint.\n",
    "processor = AutoProcessor.from_pretrained(model_id)\n",
    "# Trailing `;` suppresses the echoed return value of save_pretrained().\n",
    "processor.save_pretrained(dest);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
